{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# 使用微调的方法让模型对新闻分类更加准确\n",
    "\n",
    "在本实操手册中，开发者将使用ChatGLM3-6B base模型，对新闻分类数据集进行微调，并使用微调后的模型进行推理。\n",
    "本操作手册使用公开数据集，数据集中包含了新闻标题和新闻关键词，开发者需要根据这些信息，将新闻分类到15个类别中的一个。\n",
    "为了体现模型高效的学习能力，以及让用户更快的学习本手册，我们只使用了数据集中的一小部分数据，实际上，数据集中包含了超过40万条新闻数据。\n",
    "\n",
    "## 硬件要求\n",
    "本实践手册需要使用 FP16 精度的模型进行推理，因此，我们推荐使用至少 16GB 显存的 英伟达 GPU 来完成本实践手册。\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "首先，我们将原始的数据集格式转换为用于微调的`jsonl`格式，以方便进行微调。"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# 路径可以根据实际情况修改\n",
    "\n",
    "input_file_path = 'data/toutiao_cat_data_example.txt'\n",
    "output_file_path = 'data/toutiao_cat_data_example.jsonl'\n",
    "\n",
    "# 提示词\n",
    "prompt_prefix = \"\"\"\n",
    "你是一个专业的新闻专家，请根据我提供的新闻信息，包括新闻标题，新闻关键词等信息，你要对每一行新闻类别进行分类并告诉我结果，不要返回其他信息和多于的文字，这些类别是:\n",
    "news_story\n",
    "news_culture\n",
    "news_sports\n",
    "news_finance\n",
    "news_house\n",
    "news_car\n",
    "news_edu\n",
    "news_tech\n",
    "news_military\n",
    "news_travel\n",
    "news_world\n",
    "stock\n",
    "news_agriculture\n",
    "news_game\n",
    "请选择其中一个类别并返回，你只要返回类别的名称，不要返回其他信息。让我们开始吧:\n",
    "\"\"\"\n",
    "\n",
    "# 分类代码和名称的映射\n",
    "category_map = {\n",
    "    \"100\": \"news_story\",\n",
    "    \"101\": \"news_culture\",\n",
    "    \"102\": \"news_entertainment\",\n",
    "    \"103\": \"news_sports\",\n",
    "    \"104\": \"news_finance\",\n",
    "    \"106\": \"news_house\",\n",
    "    \"107\": \"news_car\",\n",
    "    \"108\": \"news_edu\",\n",
    "    \"109\": \"news_tech\",\n",
    "    \"110\": \"news_military\",\n",
    "    \"112\": \"news_travel\",\n",
    "    \"113\": \"news_world\",\n",
    "    \"114\": \"stock\",\n",
    "    \"115\": \"news_agriculture\",\n",
    "    \"116\": \"news_game\"\n",
    "}\n",
    "\n",
    "def process_line(line):\n",
    "    # 分割每行数据\n",
    "    parts = line.strip().split('_!_')\n",
    "    if len(parts) != 5:\n",
    "        return None\n",
    "\n",
    "    # 提取所需字段\n",
    "    _, category_code, _, news_title, news_keywords = parts\n",
    "\n",
    "    # 构造 JSON 对象\n",
    "    news_title = news_title if news_title else \"无\"\n",
    "    news_keywords = news_keywords if news_keywords else \"无\"\n",
    "    json_obj = {\n",
    "        \"context\": prompt_prefix + f\"新闻标题: {news_title}\\n 新闻关键词: {news_keywords}\\n\",\n",
    "        \"target\": category_map.get(category_code, \"无\")\n",
    "    }\n",
    "    return json_obj\n",
    "\n",
    "def convert_to_jsonl(input_path, output_path):\n",
    "    with open(input_path, 'r', encoding='utf-8') as infile, \\\n",
    "         open(output_path, 'w', encoding='utf-8') as outfile:\n",
    "        for line in infile:\n",
    "            json_obj = process_line(line)\n",
    "            if json_obj:\n",
    "                json_line = json.dumps(json_obj, ensure_ascii=False)\n",
    "                outfile.write(json_line + '\\n')\n",
    "\n",
    "# 运行转换函数\n",
    "convert_to_jsonl(input_file_path, output_file_path)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 使用没有微调的模型进行推理\n",
    "首先，我们先试用原本的模基座模型进行推理，并查看效果。"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "PROMPT = \"\"\"\n",
    "你是一个专业的新闻专家，请根据我提供的新闻信息，包括新闻标题，新闻关键词等信息，你要对每一行新闻类别进行分类并告诉我结果，不要返回其他信息和多于的文字，这些类别是:\n",
    "news_story\n",
    "news_culture\n",
    "news_sports\n",
    "news_finance\n",
    "news_house\n",
    "news_car\n",
    "news_edu\n",
    "news_tech\n",
    "news_military\n",
    "news_travel\n",
    "news_world\n",
    "stock\n",
    "news_agriculture\n",
    "news_game\n",
    "请选择其中一个类别并返回，你只要返回类别的名称，不要返回其他信息。让我们开始吧:\n",
    "新闻标题：华为手机扛下敌人子弹，是什么技术让其在战争中大放异彩？\n",
    "新闻关键词: 华为手机\n",
    "\"\"\""
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from transformers import AutoModel, AutoTokenizer\n",
    "import  torch\n",
    "# 参数设置\n",
    "model_path = \"/share/home/zyx/Models/chatglm3-6b-base\"\n",
    "tokenizer_path = model_path\n",
    "device = \"cuda\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)\n",
    "model = AutoModel.from_pretrained(model_path, load_in_8bit=False, trust_remote_code=True).to(device)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "max_new_tokens = 1024\n",
    "temperature = 0.4\n",
    "top_p = 0.9\n",
    "inputs = tokenizer(PROMPT, return_tensors=\"pt\").to(device)\n",
    "response = model.generate(input_ids=inputs[\"input_ids\"],max_new_tokens=max_new_tokens,temperature=temperature,top_p=top_p,do_sample=True)\n",
    "response = response[0, inputs[\"input_ids\"].shape[-1]:]\n",
    "origin_answer = tokenizer.decode(response, skip_special_tokens=True)\n",
    "origin_answer"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "del model\n",
    "torch.cuda.empty_cache()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "我们可以发现，模型并没有正确的对新闻进行分类。这可能是由模型训练阶段的数据导致的问题。那么，我们通过微调这个模型，能不能实现更好的效果呢？"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 使用官方的微调脚本进行微调\n",
    "\n",
    "在完成数据的切割之后，我们需要按照官方提供好的方式进行微调。我们使用的模型为`chatglm3-6b-base`基座模型，该模型相对于Chat模型，更容易上手微调，且更符合本章节的应用场景。\n",
    "\n",
    "我们将对应的参数设置好后，就可以直接执行下面的代码进行微调。该代码使用`Lora`方案进行微调，成本相较于全参微调大幅度降低。\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [
   ],
   "source": [
    "!which python\n",
    "import os\n",
    "from datetime import datetime\n",
    "import random\n",
    "\n",
    "# 定义变量\n",
    "lr = 2e-5\n",
    "num_gpus = 1\n",
    "lora_rank = 8\n",
    "lora_alpha = 32\n",
    "lora_dropout = 0.1\n",
    "max_source_len = 512\n",
    "max_target_len = 128\n",
    "dev_batch_size = 1\n",
    "grad_accumularion_steps = 2\n",
    "max_step = 300\n",
    "save_interval = 100\n",
    "max_seq_len = 512\n",
    "logging_steps=1\n",
    "\n",
    "run_name = \"news\"\n",
    "dataset_path = \"data/toutiao_cat_data_example.jsonl\"\n",
    "model_path = \"/share/home/zyx/Models/chatglm3-6b-base\"\n",
    "datestr = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
    "output_dir = f\"output/{run_name}-{datestr}-{lr}\"\n",
    "master_port = random.randint(10000, 65535)\n",
    "\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "# 构建命令\n",
    "command = f\"\"\"\n",
    "/share/home/zyx/.conda/envs/chatglm3/bin/torchrun --standalone --nnodes=1 --nproc_per_node={num_gpus} ../finetune_basemodel_demo/finetune.py \\\n",
    "    --train_format input-output \\\n",
    "    --train_file {dataset_path} \\\n",
    "    --lora_rank {lora_rank} \\\n",
    "    --lora_alpha {lora_alpha} \\\n",
    "    --lora_dropout {lora_dropout} \\\n",
    "    --max_seq_length {max_seq_len} \\\n",
    "    --preprocessing_num_workers 1 \\\n",
    "    --model_name_or_path {model_path} \\\n",
    "    --output_dir {output_dir} \\\n",
    "    --per_device_train_batch_size 1 \\\n",
    "    --gradient_accumulation_steps 2 \\\n",
    "    --max_steps {max_step} \\\n",
    "    --logging_steps {logging_steps} \\\n",
    "    --save_steps {save_interval} \\\n",
    "    --learning_rate {lr}\n",
    "\"\"\"\n",
    "\n",
    "# 在 Notebook 中执行命令\n",
    "!{command}"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-12-23T05:07:07.654081Z",
     "start_time": "2023-12-23T05:06:59.734414Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 使用微调的模型进行推理预测\n",
    "现在，我们已经完成了模型的微调，接下来，我们将使用微调后的模型进行推理。我们使用与微调时相同的提示词，并使用一些没有出现的模型效果来复现推理结果。"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoModel, AutoTokenizer\n",
    "import os\n",
    "from peft import get_peft_model, LoraConfig, TaskType\n",
    "\n",
    "# 参数设置\n",
    "lora_path = output_dir + \"pytorch_model.bin\"\n",
    "\n",
    "# 加载分词器和模型\n",
    "tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)\n",
    "model = AutoModel.from_pretrained(model_path, load_in_8bit=False, trust_remote_code=True).to(device)\n",
    "\n",
    "# LoRA 模型配置\n",
    "peft_config = LoraConfig(\n",
    "    task_type=TaskType.CAUSAL_LM, inference_mode=True,\n",
    "    target_modules=['query_key_value'],\n",
    "    r=8, lora_alpha=32, lora_dropout=0.1\n",
    ")\n",
    "model = get_peft_model(model, peft_config)\n",
    "if os.path.exists(lora_path):\n",
    "    model.load_state_dict(torch.load(lora_path), strict=False)\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "我们使用同样的提示词进行推理。"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "inputs = tokenizer(PROMPT, return_tensors=\"pt\").to(device)\n",
    "response = model.generate(input_ids=inputs[\"input_ids\"],max_new_tokens=max_new_tokens,temperature=temperature,top_p=top_p,do_sample=True)\n",
    "response = response[0, inputs[\"input_ids\"].shape[-1]:]\n",
    "response = tokenizer.decode(response, skip_special_tokens=True)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "response"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "这一次，模型成功给出了理想的答案。我们结束实操训练，删除模型并释放显存。"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "del model\n",
    "torch.cuda.empty_cache()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 总结\n",
    "在本实践手册中，我们让开发者体验了 `ChatGLM3-6B` 模型在经过微调前后在 `新闻标题分类` 任务中表现。\n",
    "我们可以发现：\n",
    "\n",
    "在具有混淆的新闻分类中，原始的模型可能受到了误导，不能有效的进行分类，而经过简单微调后的模型，已经具备了正确分类的能力。\n",
    "因此，对于有更高要求的专业分类任务，我们可以使用微调的方式对模型进行简单微调，实现更好的任务完成效果。\n"
   ],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
